{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-Task Learning Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a simple example showing how to use MXNet Gluon for multi-task learning.\n",
    "\n",
    "The network jointly learns to recognize a handwritten digit and to classify it as odd or even.\n",
    "\n",
    "For example:\n",
    "\n",
    "- 1 : 1 and odd\n",
    "- 2 : 2 and even\n",
    "- 3 : 3 and odd\n",
    "\n",
    "and so on.\n",
    "\n",
    "In this example we don't expect the two tasks to help each other much, since the odd/even label is fully determined by the digit. Multi-task learning has, however, been applied successfully to image captioning: in [A Multi-task Learning Approach for Image Captioning](https://www.ijcai.org/proceedings/2018/0168.pdf), Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo and Yu Qiao train a network to jointly classify images and generate text captions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import random\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import mxnet as mx\n",
    "from mxnet import gluon, nd, autograd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "epochs = 5\n",
    "ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()\n",
    "lr = 0.01"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "We load the standard MNIST dataset and add a second label to the existing digit label: for each digit we also return whether it is odd (1) or even (0)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = gluon.data.vision.MNIST(train=True)\n",
    "test_dataset = gluon.data.vision.MNIST(train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
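    "def transform(x, y):\n",
    "    # (H, W, C) uint8 image -> (C, H, W) float32 scaled to [0, 1]\n",
    "    x = x.transpose((2, 0, 1)).astype('float32') / 255.\n",
    "    y1 = y      # digit label, 0-9\n",
    "    y2 = y % 2  # parity label: 1 if odd, 0 if even\n",
    "    return x, np.float32(y1), np.float32(y2)"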
   ]
  },
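  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick standalone check (a NumPy-only sketch, separate from the MXNet pipeline), the parity rule `y % 2` used in `transform` maps odd digits to 1 and even digits to 0:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "digits = np.arange(10)                     # MNIST digit labels 0-9\n",
    "parity = (digits % 2).astype(np.float32)   # same rule as `transform`: 1 = odd, 0 = even\n",
    "\n",
    "for d, p in zip(digits, parity):\n",
    "    print(d, 'odd' if p == 1 else 'even')"
   ]
  },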
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We apply the transform to both original datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset_t = train_dataset.transform(transform)\n",
    "test_dataset_t = test_dataset.transform(transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We wrap the transformed datasets in DataLoaders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = gluon.data.DataLoader(train_dataset_t, shuffle=True, last_batch='rollover', batch_size=batch_size, num_workers=5)\n",
    "test_data = gluon.data.DataLoader(test_dataset_t, shuffle=False, last_batch='keep', batch_size=batch_size, num_workers=5)"
   ]
  },
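  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `last_batch` policy matters most for evaluation, because `batch_size` does not evenly divide the test set. The pure-Python sketch below does the batch arithmetic for the 10,000-image MNIST test split with `batch_size = 128`, assuming Gluon's documented semantics: `'keep'` also returns the smaller final batch, `'discard'` drops it, and `'rollover'` carries the leftover samples into the next pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_test = 10000   # size of the MNIST test split\n",
    "batch_size = 128\n",
    "\n",
    "full_batches, leftover = divmod(n_test, batch_size)\n",
    "\n",
    "# 'keep': the smaller final batch is also served, so each sample is seen once per pass\n",
    "batches_keep = full_batches + (1 if leftover else 0)\n",
    "# 'discard' / 'rollover': only full batches are served in the first pass;\n",
    "# 'rollover' carries the leftover samples into the next pass instead of dropping them\n",
    "samples_first_pass = full_batches * batch_size\n",
    "\n",
    "print(full_batches, leftover, batches_keep, samples_first_pass)"
   ]
  },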
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape: (28, 28, 1), Target Labels: (5.0, 1.0)\n"
     ]
    }
   ],
   "source": [
    "print(\"Input shape: {}, Target Labels: {}\".format(train_dataset[0][0].shape, train_dataset_t[0][1:]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-task Network\n",
    "\n",
    "The output of the shared feature layers is passed to two separate output layers, one per task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiTaskNetwork(gluon.HybridBlock):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super(MultiTaskNetwork, self).__init__()\n",
    "        \n",
    "        self.shared = gluon.nn.HybridSequential()\n",
    "        with self.shared.name_scope():\n",
    "            self.shared.add(\n",
    "                gluon.nn.Dense(128, activation='relu'),\n",
    "                gluon.nn.Dense(64, activation='relu'),\n",
    "                gluon.nn.Dense(10, activation='relu')\n",
    "            )\n",
    "        self.output1 = gluon.nn.Dense(10) # digit recognition: one logit per class\n",
    "        self.output2 = gluon.nn.Dense(1)  # odd or even: a single logit\n",
    "\n",
    "        \n",
    "    def hybrid_forward(self, F, x):\n",
    "        y = self.shared(x)\n",
    "        output1 = self.output1(y)\n",
    "        output2 = self.output2(y)\n",
    "        return output1, output2"
   ]
  },
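  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the data flow concrete, the following NumPy-only sketch mirrors the layer sizes of `MultiTaskNetwork` and prints the shape of each head's output. It uses random, untrained weights and zero biases, and `dense` is a hypothetical stand-in for `gluon.nn.Dense`, not the Gluon implementation. Like `Dense` with its default `flatten=True`, it first flattens each `(1, 28, 28)` image into a 784-vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "def dense(x, n_out, relu=False):\n",
    "    # Hypothetical stand-in for gluon.nn.Dense: random weights, zero bias\n",
    "    w = rng.standard_normal((x.shape[1], n_out)) * 0.01\n",
    "    y = x @ w\n",
    "    return np.maximum(y, 0) if relu else y\n",
    "\n",
    "batch = rng.random((4, 1, 28, 28), dtype=np.float32)\n",
    "x = batch.reshape(batch.shape[0], -1)  # flatten (N, 1, 28, 28) -> (N, 784)\n",
    "\n",
    "# Shared trunk: 784 -> 128 -> 64 -> 10, all with ReLU\n",
    "h = dense(dense(dense(x, 128, relu=True), 64, relu=True), 10, relu=True)\n",
    "\n",
    "digit_logits = dense(h, 10)    # head 1: one logit per digit class\n",
    "odd_even_logit = dense(h, 1)   # head 2: a single odd/even logit\n",
    "print(digit_logits.shape, odd_even_logit.shape)"
   ]
  },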
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
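    "We use a different loss for each output: softmax cross-entropy for the 10-class digit head, and sigmoid binary cross-entropy for the single-logit odd/even head."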
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_digits = gluon.loss.SoftmaxCELoss()\n",
    "loss_odd_even = gluon.loss.SigmoidBCELoss()"
   ]
  },
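  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, with Gluon's default settings (`sparse_label=True` for `SoftmaxCELoss`, `from_sigmoid=False` for `SigmoidBCELoss`), the two losses compute, per sample, the negative log-probability of the true digit and the binary cross-entropy of the odd/even logit. The NumPy sketch below reimplements both for illustration only; it is not Gluon's code, and it uses a numerically stable form of the binary cross-entropy that avoids computing `log(sigmoid(z))` directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax_ce(logits, labels):\n",
    "    # logits: (N, C); labels: (N,) integer class ids\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
    "    return -log_probs[np.arange(len(labels)), labels]\n",
    "\n",
    "def sigmoid_bce(logits, labels):\n",
    "    # logits, labels: (N, 1); stable form of -[y log s(z) + (1 - y) log(1 - s(z))]\n",
    "    per_elem = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))\n",
    "    return per_elem.mean(axis=1)\n",
    "\n",
    "digit_logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])\n",
    "digit_labels = np.array([0, 2])\n",
    "odd_even_logits = np.array([[3.0], [-1.0]])\n",
    "odd_even_labels = np.array([[1.0], [0.0]])\n",
    "\n",
    "print(softmax_ce(digit_logits, digit_labels))\n",
    "print(sigmoid_bce(odd_even_logits, odd_even_labels))"
   ]
  },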
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We fix the random seeds, then create and initialize the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [],
   "source": [
    "mx.random.seed(42)\n",
    "random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = MultiTaskNetwork()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.initialize(mx.init.Xavier(), ctx=ctx)\n",
    "net.hybridize() # hybridize for speed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':lr})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Accuracy\n",
    "We evaluate the accuracy of each task separately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_accuracy(net, data_iterator):\n",
    "    acc_digits = mx.metric.Accuracy(name='digits')\n",
    "    acc_odd_even = mx.metric.Accuracy(name='odd_even')\n",
    "    \n",
    "    for i, (data, label_digit, label_odd_even) in enumerate(data_iterator):\n",
    "        data = data.as_in_context(ctx)\n",
    "        label_digit = label_digit.as_in_context(ctx)\n",
    "        label_odd_even = label_odd_even.as_in_context(ctx).reshape(-1,1)\n",
    "\n",
    "        output_digit, output_odd_even = net(data)\n",
    "        \n",
    "        acc_digits.update(label_digit, output_digit.softmax())\n",
    "        acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 0.5)\n",
    "    return acc_digits.get(), acc_odd_even.get()"
   ]
  },
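  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two `update` calls in `evaluate_accuracy` score the heads differently: the digit head by comparing the arg-max class with the label, and the odd/even head by thresholding the sigmoid at 0.5. A NumPy sketch of the same rules (illustrative only, not `mx.metric.Accuracy` itself); note that `sigmoid(z) > 0.5` holds exactly when the logit `z` is positive:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def digit_accuracy(logits, labels):\n",
    "    # Predicted class = index of the largest logit (softmax does not change the arg-max)\n",
    "    return float((logits.argmax(axis=1) == labels).mean())\n",
    "\n",
    "def odd_even_accuracy(logits, labels):\n",
    "    # Threshold the sigmoid probability at 0.5\n",
    "    probs = 1.0 / (1.0 + np.exp(-logits))\n",
    "    preds = (probs > 0.5).astype(labels.dtype)\n",
    "    return float((preds == labels).mean())\n",
    "\n",
    "digit_logits = np.array([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9]])\n",
    "digit_labels = np.array([1, 0, 0])\n",
    "odd_even_logits = np.array([[2.0], [-0.5], [0.3], [-3.0]])\n",
    "odd_even_labels = np.array([[1.0], [0.0], [0.0], [0.0]])\n",
    "\n",
    "print(digit_accuracy(digit_logits, digit_labels))          # 2 of 3 correct\n",
    "print(odd_even_accuracy(odd_even_logits, odd_even_labels))  # 3 of 4 correct"
   ]
  },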
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We balance the contribution of each loss to the overall training objective with a weight `alpha` in [0, 1]: the combined loss is `(1 - alpha) * loss_digits + alpha * loss_odd_even`, so `alpha` is a hyperparameter to tune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.5 # Combine losses factor"
   ]
  },
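  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short arithmetic sketch of the combination rule used in the training loop below, with made-up per-sample loss values (hypothetical numbers, not results from this notebook). `alpha = 0` trains only the digit head and `alpha = 1` only the odd/even head; values in between let both tasks send gradients into the shared layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def combine(l_digits, l_odd_even, alpha):\n",
    "    # Same rule as the training loop: (1 - alpha) * digit loss + alpha * odd/even loss\n",
    "    return (1 - alpha) * l_digits + alpha * l_odd_even\n",
    "\n",
    "l_digits = np.array([0.30, 0.10])     # hypothetical per-sample digit losses\n",
    "l_odd_even = np.array([0.05, 0.20])   # hypothetical per-sample odd/even losses\n",
    "\n",
    "for a in (0.0, 0.5, 1.0):\n",
    "    print(a, combine(l_digits, l_odd_even, a))"
   ]
  },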
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [0], Acc Digits   0.8945 Loss Digits   0.3409\n",
      "Epoch [0], Acc Odd/Even 0.9561 Loss Odd/Even 0.1152\n",
      "Epoch [0], Testing Accuracies (('digits', 0.9487179487179487), ('odd_even', 0.9770633012820513))\n",
      "Epoch [1], Acc Digits   0.9576 Loss Digits   0.1475\n",
      "Epoch [1], Acc Odd/Even 0.9804 Loss Odd/Even 0.0559\n",
      "Epoch [1], Testing Accuracies (('digits', 0.9642427884615384), ('odd_even', 0.9826722756410257))\n",
      "Epoch [2], Acc Digits   0.9681 Loss Digits   0.1124\n",
      "Epoch [2], Acc Odd/Even 0.9852 Loss Odd/Even 0.0418\n",
      "Epoch [2], Testing Accuracies (('digits', 0.9580328525641025), ('odd_even', 0.9846754807692307))\n",
      "Epoch [3], Acc Digits   0.9734 Loss Digits   0.0961\n",
      "Epoch [3], Acc Odd/Even 0.9884 Loss Odd/Even 0.0340\n",
      "Epoch [3], Testing Accuracies (('digits', 0.9670472756410257), ('odd_even', 0.9839743589743589))\n",
      "Epoch [4], Acc Digits   0.9762 Loss Digits   0.0848\n",
      "Epoch [4], Acc Odd/Even 0.9894 Loss Odd/Even 0.0310\n",
      "Epoch [4], Testing Accuracies (('digits', 0.9652887658227848), ('odd_even', 0.9858583860759493))\n"
     ]
    }
   ],
   "source": [
    "for e in range(epochs):\n",
    "    # Accuracies for each task\n",
    "    acc_digits = mx.metric.Accuracy(name='digits')\n",
    "    acc_odd_even = mx.metric.Accuracy(name='odd_even')\n",
    "    # Accumulative losses\n",
    "    l_digits_ = 0.\n",
    "    l_odd_even_ = 0. \n",
    "    \n",
    "    for i, (data, label_digit, label_odd_even) in enumerate(train_data):\n",
    "        data = data.as_in_context(ctx)\n",
    "        label_digit = label_digit.as_in_context(ctx)\n",
    "        label_odd_even = label_odd_even.as_in_context(ctx).reshape(-1,1)\n",
    "        \n",
    "        with autograd.record():\n",
    "            output_digit, output_odd_even = net(data)\n",
    "            l_digits = loss_digits(output_digit, label_digit)\n",
    "            l_odd_even = loss_odd_even(output_odd_even, label_odd_even)\n",
    "\n",
    "            # Combine the loss of each task\n",
    "            l_combined = (1-alpha)*l_digits + alpha*l_odd_even\n",
    "            \n",
    "        l_combined.backward()\n",
    "        trainer.step(data.shape[0])\n",
    "        \n",
    "        l_digits_ += l_digits.mean()\n",
    "        l_odd_even_ += l_odd_even.mean()\n",
    "        acc_digits.update(label_digit, output_digit.softmax())\n",
    "        acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 0.5)\n",
    "        \n",
    "    print(\"Epoch [{}], Acc Digits   {:.4f} Loss Digits   {:.4f}\".format(\n",
    "        e, acc_digits.get()[1], l_digits_.asscalar()/(i+1)))\n",
    "    print(\"Epoch [{}], Acc Odd/Even {:.4f} Loss Odd/Even {:.4f}\".format(\n",
    "        e, acc_odd_even.get()[1], l_odd_even_.asscalar()/(i+1)))\n",
    "    print(\"Epoch [{}], Testing Accuracies {}\".format(e, evaluate_accuracy(net, test_data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_random_data():\n",
    "    idx = random.randint(0, len(test_dataset) - 1)  # randint includes both endpoints\n",
    "\n",
    "    img = test_dataset[idx][0]\n",
    "    data, _, _ = test_dataset_t[idx]\n",
    "    data = data.as_in_context(ctx).expand_dims(axis=0)\n",
    "\n",
    "    plt.imshow(img.squeeze().asnumpy(), cmap='gray')\n",
    "    \n",
    "    return data"
   ]
  },
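  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Python's `random.randint(a, b)` includes both endpoints, so a valid index into a dataset of length `n` must be drawn with `randint(0, n - 1)` (equivalently `random.randrange(n)`). A small standalone check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "n = 10000  # size of the MNIST test set\n",
    "random.seed(0)\n",
    "# randint(a, b) can return b itself, so the upper bound must be n - 1\n",
    "indices = [random.randint(0, n - 1) for _ in range(50000)]\n",
    "print(min(indices), max(indices))"
   ]
  },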
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted digit: [9.], odd: [1.]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADeVJREFUeJzt3X+MFPX9x/HXG6QGAQ3aiBdLpd9Ga6pBak5joqk01caaRuAfUhMbjE2viTUpEVFCNT31Dxu1rdWYJldLCk2/QhUb+KPWWuKP1jQNIKiotFJC00OEkjNBEiNyvPvHzdlTbz6zzs7uzPF+PpLL7e57Z+ad5V7M7H5m9mPuLgDxTKq7AQD1IPxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4I6oZsbMzNOJwQ6zN2tlee1tec3s6vM7O9mtsvMVrSzLgDdZWXP7TezyZL+IelKSYOSNku61t1fSyzDnh/osG7s+S+WtMvdd7v7EUlrJS1oY30Auqid8J8p6d9j7g9mj32ImfWZ2RYz29LGtgBUrOMf+Ln7gKQBicN+oEna2fPvlTR7zP3PZI8BmADaCf9mSWeb2efM7FOSvilpYzVtAei00of97n7UzG6S9JSkyZJWufurlXUGoKNKD/WV2hjv+YGO68pJPgAmLsIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCKj1FtySZ2R5J70galnTU3XuraApA57UV/sxX3P1gBesB0EUc9gNBtRt+l/RHM9tqZn1VNASgO9o97L/M3fea2emSnjazne7+/NgnZP8p8B8D0DDm7tWsyKxf0mF3vz/xnGo2BiCXu1srzyt92G9m08xsxuhtSV+TtKPs+gB0VzuH/bMk/c7MRtfz/+7+h0q6AtBxlR32t7QxDvuBjuv4YT+AiY3wA0ERfiAowg8ERfiBoAg/EFQVV/WhwaZPn56sL1++vK3lb7755mT97bffzq3deeedyWUffvjhZP3o0aPJOtLY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUFzSOwFMnTo1WV+xYkVurWgcftq0acl69n0NuTr591M0zr9s2bJk/ciRI1W2M2FwSS+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIpx/i4oGqe//PLLk/Vbb701WZ8/f/4nballQ0NDbdWnTJmSWzvrrLNK9TTqySefTNafe+653NoDDzyQXHYinyPAOD+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCKpwnN/MVkn6hqQD7n5+9tipktZJmiNpj6TF7p7/Be3/W9dxOc5/0kknJesPPvhgsn7DDTdU2c6H7NixI1m/5557kvVt27Yl6zt37kzWZ8yYkVt76qmnkstecsklyXo7zjnnnGR9165dHdt2p1U5zv8rSVd95LEVkja5+9mSNmX3AUwgheF39+clffQ0rgWSVme3V0taWHFfADqs7Hv+We6+L7v9lqRZFfUDoEvanqvP3T31Xt7M+iT1tbsdANUqu+ffb2Y9kpT9PpD3RHcfcPded+8tuS0AHVA2/BslLcluL5G0oZp2AHRLYfjN7FFJf5X0BTMbNLNvS/qRpCvN7A1JV2T3AUwghe/53f3anNJXK+5lwrriiiuS9XbH8Q8ePJisr1u3Lrd2yy23JJd97733SvXUqp6entq2jTTO8AOCIvxAUIQfCIrwA0ERfiAowg8E1fbpvVGkprJevnx5R7f9yCOPJOsrV67s2LZPOCH9J7Jo0aJk/aGHHsqtnX766aV6atUzzzyTW9u7d29Htz0RsOcHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY52/RHXfckVu79NJL21p30Tj+3Xff3db6U84999xk
fenSpcl6X19zv6Ht3nvvza29++67XeykmdjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPO3qJPXnq9ZsyZZLxqTTk03XTROv3jx4mT9tNNOS9aLpnjvpNR3BUjSs88+251GJij2/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVOE4v5mtkvQNSQfc/fzssX5J35H0n+xpK939951qsgk2b96cW7v++uvbWveGDRuS9SNHjiTrU6dOza2dfPLJpXoa9f777yfr1113XbKemlNg7ty5pXoa9dhjjyXrTAGe1sqe/1eSrhrn8Z+6+7zs57gOPnA8Kgy/uz8vaagLvQDoonbe899kZi+b2Sozm1lZRwC6omz4fy7p85LmSdon6cd5TzSzPjPbYmZbSm4LQAeUCr+773f3YXc/JukXki5OPHfA3XvdvbdskwCqVyr8ZtYz5u4iSTuqaQdAt7Qy1PeopPmSPm1mg5J+KGm+mc2T5JL2SPpuB3sE0AHWzeuxzay+i7/bNGlS/kHS448/nlx24cKFVbdTmRdeeCFZv+uuu5L1ovMIisbiU4p6mz9/frI+PDxcetsTmbtbK8/jDD8gKMIPBEX4gaAIPxAU4QeCIvxAUHx1d4uOHTuWW7vxxhuTy+7fvz9ZL7osdufOncn6E088kVsr+nrrw4cPJ+snnnhisl40HGeWP+qUek0ladOmTcl61KG8qrDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGguKQXSWeccUay/uabb5Ze9/bt25P1Cy+8sPS6I+OSXgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFNfzI6m/v7+t5VNTfK9du7atdaM97PmBoAg/EBThB4Ii/EBQhB8IivADQRF+IKjC6/nNbLakNZJmSXJJA+7+MzM7VdI6SXMk7ZG02N3fLlgX1/M3zKJFi5L11JwAklT093Pffffl1m677bbksiinyuv5j0pa5u5flHSJpO+Z2RclrZC0yd3PlrQpuw9ggigMv7vvc/cXs9vvSHpd0pmSFkhanT1ttaSFnWoSQPU+0Xt+M5sj6UuS/iZplrvvy0pvaeRtAYAJouVz+81suqT1kpa6+6Gxc7C5u+e9nzezPkl97TYKoFot7fnNbIpGgv8bdx/9BGi/mfVk9R5JB8Zb1t0H3L3X3XuraBhANQrDbyO7+F9Ket3dfzKmtFHSkuz2Ekkbqm8PQKe0MtR3maQ/S3pF0uicyis18r7/t5I+K+lfGhnqGypYF0N9DfPSSy8l63Pnzk3Wh4aS/+S64IILcmuDg4PJZVFOq0N9he/53f0vkvJW9tVP0hSA5uAMPyAowg8ERfiBoAg/EBThB4Ii/EBQfHX3ca7ostnzzjsvWR8eHk7Wb7/99mSdsfzmYs8PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0EVXs9f6ca4nr8j5syZk1vbtm1bctlTTjklWd+6dWuyftFFFyXr6L4qv7obwHGI8ANBEX4gKMIPBEX4gaAIPxAU4QeC4nr+48DSpUtza0Xj+EX6+/vbWh7NxZ4fCIrwA0ERfiAowg8ERfiBoAg/EBThB4IqvJ7fzGZLWiNpliSXNODuPzOzfknfkfSf7Kkr3f33Beviev4SrrnmmmR9/fr1ubXJkye3te1Jk9g/TDStXs/fykk+RyUtc/cXzWyGpK1m9nRW+6m731+2SQD1KQy/u++TtC+7/Y6ZvS7pzE43BqCzPtExnZnNkfQlSX/LHrrJzF42s1VmNjNnmT4z22JmW9rqFEClWg6/mU2XtF7SUnc/JOnnkj4vaZ5Gjgx+PN5y7j7g7r3u3ltBvwAq0lL4zWyKRoL/G3d/QpLcfb+7D7v7MUm/kHRx59oEULXC8JuZSfqlpNfd/SdjHu8Z87RFknZU3x6ATmnl0/5LJX1L0itmtj17bKWka81snkaG//ZI+m5HOoR2796drB86dCi3NnPmuB/FfOD++xmsiaqVT/v/Imm8ccPkmD6AZuMM
DiAowg8ERfiBoAg/EBThB4Ii/EBQTNENHGeYohtAEuEHgiL8QFCEHwiK8ANBEX4gKMIPBNXtKboPSvrXmPufzh5roqb21tS+JHorq8rezmr1iV09yedjGzfb0tTv9mtqb03tS6K3surqjcN+ICjCDwRVd/gHat5+SlN7a2pfEr2VVUtvtb7nB1Cfuvf8AGpSS/jN7Coz+7uZ7TKzFXX0kMfM9pjZK2a2ve4pxrJp0A6Y2Y4xj51qZk+b2RvZ7/R3c3e3t34z25u9dtvN7OqaepttZs+Y2Wtm9qqZfT97vNbXLtFXLa9b1w/7zWyypH9IulLSoKTNkq5199e62kgOM9sjqdfdax8TNrMvSzosaY27n589dq+kIXf/UfYf50x3v60hvfVLOlz3zM3ZhDI9Y2eWlrRQ0vWq8bVL9LVYNbxudez5L5a0y913u/sRSWslLaihj8Zz9+clDX3k4QWSVme3V2vkj6frcnprBHff5+4vZrffkTQ6s3Str12ir1rUEf4zJf17zP1BNWvKb5f0RzPbamZ9dTczjlnZtOmS9JakWXU2M47CmZu76SMzSzfmtSsz43XV+MDv4y5z9wslfV3S97LD20bykfdsTRquaWnm5m4ZZ2bpD9T52pWd8bpqdYR/r6TZY+5/JnusEdx9b/b7gKTfqXmzD+8fnSQ1+32g5n4+0KSZm8ebWVoNeO2aNON1HeHfLOlsM/ucmX1K0jclbayhj48xs2nZBzEys2mSvqbmzT68UdKS7PYSSRtq7OVDmjJzc97M0qr5tWvcjNfu3vUfSVdr5BP/f0r6QR095PT1f5Jeyn5erbs3SY9q5DDwfY18NvJtSadJ2iTpDUl/knRqg3r7taRXJL2skaD11NTbZRo5pH9Z0vbs5+q6X7tEX7W8bpzhBwTFB35AUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4L6L4bahh5ke9v1AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data = get_random_data()\n",
    "\n",
    "digit, odd_even = net(data)\n",
    "\n",
    "digit = digit.argmax(axis=1)[0].asnumpy()\n",
    "odd_even = (odd_even.sigmoid()[0] > 0.5).asnumpy()\n",
    "\n",
    "print(\"Predicted digit: {}, odd: {}\".format(digit, odd_even))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
